// Copyright 2022 The Mumble Developers. All rights reserved.
// Use of this source code is governed by a BSD-style license
// that can be found in the LICENSE file at the root of the
// Mumble source tree or at <https://www.mumble.info/LICENSE>.

#include "Table.h"
#include "AccessException.h"
#include "ConversionUtils.h"
#include "Database.h"
#include "FormatException.h"
#include "Trigger.h"
#include "Utils.h"

#include "database/Column.h"
#include <soci/soci.h>

#include <boost/algorithm/string.hpp>

#include <nlohmann/json.hpp>

#include <cassert>
#include <utility>
#include <vector>

namespace mumble {
namespace db {

	// Out-of-line definition of the static constexpr data member Table::BACKUP_SUFFIX.
	// If this looks weird to you, check out https://stackoverflow.com/a/8016853
	// (Essentially this is needed to avoid an undefined reference error for this constant: before C++17 a static
	// constexpr member that is ODR-used - e.g. bound to a reference, as happens when it is passed to boost::ends_with -
	// requires a namespace-scope definition. Since C++17 such members are implicitly inline, which makes this line
	// redundant (but harmless).)
	constexpr const char *Table::BACKUP_SUFFIX;

	// Convenience constructor: creates a table without name and columns (these can be supplied later on, e.g. via
	// setName/setColumns).
	Table::Table(soci::session &sql, Backend backend, Database *database)
		: Table(sql, backend, std::string(), std::vector< Column >(), database) {}

	// Primary constructor: stores the table specification along with the session/backend it operates on and the
	// (optional) Database it belongs to.
	Table::Table(soci::session &sql, Backend backend, const std::string &name, const std::vector< Column > &columns,
				 Database *database)
		: m_name(name), m_columns(columns), m_sql(sql), m_backend(backend), m_database(database) {
		// Sanity-check the given specification (consists of assertions only, so this is a no-op in release builds)
		performCtorAssertions();
	}

	const std::string &Table::getName() const { return m_name; }
	void Table::setName(const std::string &name) { m_name = name; }

	const std::vector< Column > &Table::getColumns() const { return m_columns; }

	void Table::setColumns(const std::vector< Column > &columns) {
		// we don't support overwriting columns after they have been initialized already
		assert(m_columns.empty());

		m_columns = columns;
	}

	const Column *Table::findColumn(const std::string &name) const {
		for (const Column &currentCol : m_columns) {
			if (currentCol.getName() == name) {
				return &currentCol;
			}
		}

		return nullptr;
	}

	bool Table::containsColumn(const std::string &name) const { return findColumn(name) != nullptr; }

	// Creates this table in the database: the column definitions (type, default value, auto-increment, column
	// constraints), the primary key, the foreign keys and afterwards all registered indices and triggers.
	//
	// Everything runs inside the transaction obtained from ensureTransaction(). SOCI errors raised while executing
	// the generated statements are rethrown as AccessException.
	//
	// Debug builds additionally assert that the table has a name and columns, that an AUTOINCREMENT column is the
	// single-column integer primary key and that every primary-key column carries a NOT NULL constraint.
	//
	// NOTE(review): for PostgreSQL, the lo_manage triggers for BLOB columns are appended to m_trigger on every call,
	// so calling create() more than once on the same object (e.g. after destroy()) registers them a second time.
	void Table::create() {
		assert(!m_name.empty());
		assert(!m_columns.empty());

		TransactionHolder transaction = ensureTransaction();

		// Every column contributes `"name" TYPE [DEFAULT x] [auto-increment] [constraints...], `
		std::string createQuery = "CREATE TABLE \"" + m_name + "\" (";

		for (const Column &currentColumn : m_columns) {
			createQuery +=
				"\"" + currentColumn.getName() + "\" " + currentColumn.getType().sqlRepresentation(m_backend);

			if (currentColumn.hasDefaultValue()) {
				createQuery += " DEFAULT ";
				std::string defaultValue = currentColumn.getDefaultValue();
				if (currentColumn.getType().isStringType() && defaultValue != "NULL") {
					// NOTE(review): a default of "'NULL'" skips the escaping but is still wrapped in quotes below,
					// which yields `''NULL''` instead of a valid string literal - confirm that this is intended.
					if (defaultValue != "'NULL'") {
						// Escape single quotes by doubling them up
						boost::replace_all(defaultValue, "'", "''");
					}

					// We'll have to wrap the value in quotes in order to make sure it gets recognized as a String
					createQuery += "'" + defaultValue + "'";
				} else {
					// If the default is not a String-type, we don't want to see spaces in it (it is inserted
					// verbatim into the query)
					assert(defaultValue.find(" ") == std::string::npos);

					createQuery += defaultValue;
				}
			}

			if (currentColumn.testFlag(Column::Flag::AUTOINCREMENT)) {
				// Assert that the current column is the primary key. Otherwise, we don't support auto-incrementing (as
				// it is not supported across different RDBMS).
				assert(hasPrimaryKey());
				assert(!m_primaryKey.isCompositeKey());
				assert(m_primaryKey.getColumnNames()[0] == currentColumn.getName());
				// Auto-increment only makes sense for numeric columns
				assert(currentColumn.getType().getType() == DataType::Integer);

				switch (m_backend) {
					case Backend::SQLite:
						// In SQLite an integer primary key is auto-increment by default. The explicit use of the
						// AUTOINCREMENT keyword is even discouraged.
						break;
					case Backend::MySQL:
						createQuery += " AUTO_INCREMENT";
						break;
					case Backend::PostgreSQL:
						createQuery += " GENERATED BY DEFAULT AS IDENTITY";
						break;
				}
			}

			for (const Constraint &currentConstraint : currentColumn.getConstraints()) {
				createQuery += " " + currentConstraint.sql(m_backend);
			}

			createQuery += ", ";
		}

		// Table-level constraints: primary key (if any) followed by all foreign keys
		if (m_primaryKey.isValid()) {
			createQuery += m_primaryKey.sql() + ", ";

#ifndef NDEBUG
			for (const std::string &name : m_primaryKey.getColumnNames()) {
				const Column *col = findColumn(name);

				assert(col);
				// For historic reasons in SQLite a primary key does not imply NOT NULL (as the SQL standard implies),
				// so we have to make sure to explicitly label the columns in a primary key as NOT NULL. In order to not
				// make things more complicated than they need to be, we simply require this for all backends.
				assert(std::find_if(col->getConstraints().begin(), col->getConstraints().end(),
									[](const Constraint &c) { return c.getType() == Constraint::NotNull; })
					   != col->getConstraints().end());
			}
#endif
		}

		for (const ForeignKey &foreignKey : m_foreignKeys) {
			createQuery += foreignKey.sql() + ", ";
		}

		// Remove trailing ", " (always present, as there is at least one column)
		createQuery.erase(createQuery.size() - 2);

		createQuery += ")";


		// In PostgreSQL, one is instructed to create special triggers for every BLOB column, that call
		// lo_manage, which (presumably) makes sure that the created BLOBs are deleted once they are no longer
		// referenced.
		// The triggers are only registered here (applyToDB = false); they get created in the database together with
		// all other registered triggers further below.
		if (m_backend == Backend::PostgreSQL) {
			for (const Column &currentCol : getColumns()) {
				if (currentCol.getType().getType() != DataType::Blob) {
					continue;
				}
				std::string triggerBody = "EXECUTE PROCEDURE lo_manage(\"" + currentCol.getName() + "\");";

				Trigger updateTrigger = Trigger(currentCol.getName() + "_lo_manage_update_trigger",
												Trigger::Timing::Before, Trigger::Event::Update, triggerBody);
				updateTrigger.setDropBeforeDeleteTable(false);
				Trigger deleteTrigger = Trigger(currentCol.getName() + "_lo_manage_delete_trigger",
												Trigger::Timing::Before, Trigger::Event::Delete, triggerBody);
				deleteTrigger.setDropBeforeDeleteTable(false);

				addTrigger(std::move(updateTrigger), false);
				addTrigger(std::move(deleteTrigger), false);
			}
		}

		try {
			m_sql << createQuery;

			// Also create all necessary indices
			for (const Index &currentIndex : m_indices) {
				m_sql << currentIndex.creationQuery(*this, m_backend);
			}

			// Finally, add triggers (and remember that they now exist in the database)
			for (Trigger &currentTrigger : m_trigger) {
				m_sql << currentTrigger.creationQuery(*this, m_backend);

				currentTrigger.setCreated(true);
			}
		} catch (const soci::soci_error &e) {
			throw AccessException(e.what());
		}

		transaction.commit();
	}

	void Table::migrate(unsigned int fromSchemeVersion, unsigned int toSchemeVersion) {
		(void) fromSchemeVersion;
		(void) toSchemeVersion;
		// The default implementation simply imports all data from the old table into the new one. The previously
		// existing table will have been renamed to include the suffix Database::OLD_TABLE_SUFFIX. Other than that the
		// table name and the columns in that table are assumed to be equal to those of the table represented by this
		// class.

		std::string columns;
		for (auto it = m_columns.begin(); it != m_columns.end(); ++it) {
			if (it != m_columns.begin()) {
				columns += ", ";
			}

			columns += "\"" + it->getName() + "\"";
		}

		try {
			m_sql << "INSERT INTO \"" << getName() << "\" (" << columns << ") SELECT " << columns << " FROM \""
				  << getName() << Database::OLD_TABLE_SUFFIX << "\"";
		} catch (const soci::soci_error &e) {
			throw AccessException("Failed at migrating table \"" + getName() + "\": " + e.what());
		}
	}

	void Table::postMigrationAction(unsigned int fromSchemeVersion, unsigned int toSchemeVersion) {
		// Default implementation: there is nothing to be done after the migration
		static_cast< void >(fromSchemeVersion);
		static_cast< void >(toSchemeVersion);
	}

	void Table::destroy() {
		assert(!m_name.empty());

		// Starting the transaction happens inside the try-block as well, so that every SOCI error raised by this
		// function is reported as an AccessException.
		try {
			TransactionHolder transaction = ensureTransaction();

			const std::string dropStatement = "DROP TABLE \"" + m_name + "\"";
			m_sql << dropStatement;

			transaction.commit();
		} catch (const soci::soci_error &e) {
			throw AccessException(e.what());
		}
	}

	void Table::clear() {
		assert(!m_name.empty());

		TransactionHolder transaction = ensureTransaction();

		// A DELETE without a WHERE clause removes every row of the table (the table itself is kept)
		const std::string deleteAllRows = "DELETE FROM \"" + m_name + "\"";
		try {
			m_sql << deleteAllRows;
		} catch (const soci::soci_error &e) {
			throw AccessException(e.what());
		}

		transaction.commit();
	}

	// Returns the Database this table belongs to (may be nullptr for standalone tables)
	Database *Table::getDatabase() {
		return m_database;
	}

	// Const overload of getDatabase()
	const Database *Table::getDatabase() const {
		return m_database;
	}

	// Associates this table with the given Database (nullptr detaches it)
	void Table::setDatabase(Database *database) {
		m_database = database;
	}

	// Returns all indices registered for this table
	const std::vector< Index > &Table::getIndices() const {
		return m_indices;
	}

	// Registers the given index for this table and, if applyToDB is set, creates it in the database first. A failure
	// to create the index is reported as AccessException, in which case the index is not registered.
	void Table::addIndex(const Index &index, bool applyToDB) {
		if (!applyToDB) {
			// In-memory registration only
			m_indices.push_back(index);
			return;
		}

		TransactionHolder transaction = ensureTransaction();

		try {
			m_sql << index.creationQuery(*this, m_backend);
		} catch (const soci::soci_error &e) {
			throw AccessException("Failed at creating index \"" + index.getName() + "\": " + e.what());
		}

		transaction.commit();

		m_indices.push_back(index);
	}

	// Unregisters the given index from this table and, if applyToDB is set, drops it from the database.
	// Returns false (and does nothing) if the index is not registered for this table; true otherwise.
	// A failure to drop the index is reported as AccessException, in which case the index stays registered.
	bool Table::removeIndex(const Index &index, bool applyToDB) {
		auto it = std::find(m_indices.begin(), m_indices.end(), index);
		if (it == m_indices.end()) {
			return false;
		}

		// The database is updated BEFORE the in-memory entry is erased:
		// - `index` may alias the very element that is about to be erased (e.g. removeIndex(getIndices()[0], true)),
		//   so using it after the erase would read a shifted or destroyed element.
		// - If dropping fails, the in-memory list stays consistent with the database (mirroring addIndex, which only
		//   registers indices that were created successfully).
		if (applyToDB) {
			TransactionHolder transaction = ensureTransaction();

			try {
				m_sql << index.dropQuery(*this, m_backend);
			} catch (const soci::soci_error &e) {
				throw AccessException("Failed at dropping index \"" + index.getName() + "\": " + e.what());
			}

			transaction.commit();
		}

		// `it` is still valid: m_indices has not been modified since the lookup
		m_indices.erase(it);

		return true;
	}

	// Returns all triggers registered for this table
	const std::vector< Trigger > &Table::getTrigger() const {
		return m_trigger;
	}

	// Registers the given trigger for this table and, if applyToDB is set, creates it in the database first (marking it
	// as created). A failure to create the trigger is reported as AccessException, in which case it is not registered.
	void Table::addTrigger(Trigger trigger, bool applyToDB) {
		if (!applyToDB) {
			// In-memory registration only; the trigger is not created in the database here
			m_trigger.push_back(std::move(trigger));
			return;
		}

		TransactionHolder transaction = ensureTransaction();

		try {
			m_sql << trigger.creationQuery(*this, m_backend);
		} catch (const soci::soci_error &e) {
			throw AccessException("Failed at creating trigger \"" + trigger.getName() + "\": " + e.what());
		}

		// The creation query went through, so the trigger now exists in the database
		trigger.setCreated(true);

		transaction.commit();

		m_trigger.push_back(std::move(trigger));
	}

	// Unregisters the given trigger from this table and, if applyToDB is set, drops it from the database.
	// Returns false (and does nothing) if the trigger is not registered for this table; true otherwise.
	// A failure to drop the trigger is reported as AccessException, in which case the trigger stays registered.
	bool Table::removeTrigger(const Trigger &trigger, bool applyToDB) {
		auto it = std::find(m_trigger.begin(), m_trigger.end(), trigger);
		if (it == m_trigger.end()) {
			return false;
		}

		// The database is updated BEFORE the in-memory entry is erased:
		// - `trigger` may alias the very element that is about to be erased (e.g. removeTrigger(getTrigger()[0], true)),
		//   so using it after the erase would read a shifted or destroyed element.
		// - If dropping fails, the in-memory list stays consistent with the database (mirroring addTrigger, which only
		//   registers DB-applied triggers after their creation succeeded).
		if (applyToDB) {
			TransactionHolder transaction = ensureTransaction();

			try {
				m_sql << trigger.dropQuery(*this, m_backend);
			} catch (const soci::soci_error &e) {
				throw AccessException("Failed at dropping trigger \"" + trigger.getName() + "\": " + e.what());
			}

			transaction.commit();
		}

		// `it` is still valid: m_trigger has not been modified since the lookup
		m_trigger.erase(it);

		return true;
	}

	// Checks whether a (valid) primary key has been set for this table
	bool Table::hasPrimaryKey() const {
		return m_primaryKey.isValid();
	}

	// Returns this table's primary key (check hasPrimaryKey() to see whether it is valid)
	const PrimaryKey &Table::getPrimaryKey() const {
		return m_primaryKey;
	}

	// Replaces this table's primary key (in memory only)
	void Table::setPrimaryKey(const PrimaryKey &key) {
		m_primaryKey = key;
	}

	// Returns all foreign keys registered for this table
	const std::vector< ForeignKey > &Table::getForeignKeys() const {
		return m_foreignKeys;
	}

	// Registers a foreign key. Adding the same key twice is a programming error (checked in debug builds only).
	void Table::addForeignKey(const ForeignKey &key) {
		assert(std::find(m_foreignKeys.begin(), m_foreignKeys.end(), key) == m_foreignKeys.end());

		m_foreignKeys.push_back(key);
	}

	// Unregisters the given foreign key; does nothing if it is not registered
	void Table::removeForeignKey(const ForeignKey &key) {
		const auto match = std::find(m_foreignKeys.cbegin(), m_foreignKeys.cend(), key);

		if (match == m_foreignKeys.cend()) {
			return;
		}

		m_foreignKeys.erase(match);
	}

	// Unregisters all foreign keys
	void Table::clearForeignKeys() {
		m_foreignKeys.clear();
	}

	TransactionHolder Table::ensureTransaction() {
		if (!m_database) {
			// Standalone table: we have to be content with a transaction only known locally
			return TransactionHolder(m_sql, true);
		}

		// Tables that are part of a Database use the global (database-wide known) transaction
		return m_database->ensureTransaction();
	}

#define THROW_FORMATERROR(msg) throw FormatException(std::string("JSON-Import (table \"") + m_name + "\"): " + msg)
	// Imports the given JSON representation of a table (the format produced by exportToJSON) into this table.
	//
	// The JSON object must contain exactly the fields "column_names" and "column_types" (arrays of strings of equal,
	// non-zero length) and "rows" (an array of arrays holding one entry per column; null entries are inserted as NULL).
	// If this table already has pre-defined columns, the JSON columns must match them by name and type - their order
	// may differ from the pre-defined one. Otherwise, the columns are taken over from the JSON specification.
	// Values for Blob/Binary columns are expected to be hex-encoded strings.
	//
	// @param json The table's JSON representation
	// @param create Whether to create the table in the database (via create()) before inserting the rows
	// @throws FormatException if the JSON does not have the expected format or does not match the table's columns
	void Table::importFromJSON(const nlohmann::json &json, bool create) {
		assert(!m_name.empty());

		if (!json.is_object()) {
			THROW_FORMATERROR("Expected table to be represented as a single JSON object");
		}
		// Validate that the expected fields are present and of the expected type
		const std::vector< std::pair< std::string, nlohmann::json::value_t > > expectedFields = {
			{ "column_names", nlohmann::json::value_t::array },
			{ "column_types", nlohmann::json::value_t::array },
			{ "rows", nlohmann::json::value_t::array }
		};
		for (const std::pair< std::string, nlohmann::json::value_t > &currentPair : expectedFields) {
			if (!json.contains(currentPair.first)) {
				THROW_FORMATERROR("Table specification is missing the \"" + currentPair.first + "\" field");
			}
			if (json[currentPair.first].type() != currentPair.second) {
				THROW_FORMATERROR("Field \"" + currentPair.first + "\" is of the wrong type");
			}
		}
		// All expected fields are present, so any surplus in size means that there are unexpected extra fields
		if (json.size() > expectedFields.size()) {
			THROW_FORMATERROR("Table spec is expected to contain only " + std::to_string(expectedFields.size())
							  + " fields but contained " + std::to_string(json.size()));
		}

		const nlohmann::json &colNames = json["column_names"];
		const nlohmann::json &colTypes = json["column_types"];
		const nlohmann::json &rows     = json["rows"];

		// Validate the column specification
		if (colNames.size() != colTypes.size()) {
			THROW_FORMATERROR("Amount of column names (" + std::to_string(colNames.size())
							  + ") does not match column types (" + std::to_string(colTypes.size()) + ")");
		}
		if (colNames.empty()) {
			THROW_FORMATERROR("Table specification does not contain any columns");
		}
		for (std::size_t i = 0; i < colNames.size(); ++i) {
			if (!colNames[i].is_string()) {
				THROW_FORMATERROR("Encountered non-string column name specification at position "
								  + std::to_string(i + 1));
			}
			if (!colTypes[i].is_string()) {
				THROW_FORMATERROR("Encountered non-string column type specification at position "
								  + std::to_string(i + 1));
			}

			const std::string currentName = colNames[i].get< std::string >();

			// Column names are embedded into the generated queries as double-quoted identifiers, so neither spaces
			// nor double quotes (which would terminate the quoted identifier) are acceptable.
			if (boost::contains(currentName, " ") || boost::contains(currentName, "\"")) {
				THROW_FORMATERROR("Invalid column name \"" + currentName + "\"");
			}
			// Each column may only be specified once (otherwise the generated queries would be invalid)
			for (std::size_t j = 0; j < i; ++j) {
				if (colNames[j].get< std::string >() == currentName) {
					THROW_FORMATERROR("Column \"" + currentName + "\" is specified more than once");
				}
			}
			try {
				// Check if we can convert the given string to a known data type
				DataType::fromSQLRepresentation(colTypes[i].get< std::string >());
			} catch (const UnknownDataTypeException &e) {
				THROW_FORMATERROR("Unknown column type \"" + colTypes[i].get< std::string >() + "\" for column \""
								  + currentName + "\": " + e.what());
			}
		}
		// Validate the shape of the rows
		for (std::size_t i = 0; i < rows.size(); ++i) {
			const nlohmann::json &currentRow = rows.at(i);

			if (!currentRow.is_array()) {
				THROW_FORMATERROR("Row entry " + std::to_string(i + 1) + " is not of type array");
			}
			if (currentRow.size() != colNames.size()) {
				THROW_FORMATERROR("Row " + std::to_string(i + 1) + " contains " + std::to_string(currentRow.size())
								  + " entries, but " + std::to_string(colNames.size()) + " were expected");
			}
		}

		// For every column of the JSON data (in JSON order) the column of this table it maps to. Pre-defined columns
		// are matched by name, so their order in m_columns may differ from the JSON order - hence the JSON position
		// must never be used to index into m_columns directly.
		std::vector< const Column * > targetColumns;
		targetColumns.reserve(colNames.size());

		if (!m_columns.empty()) {
			// Make sure that the specified columns and types match with our stored specification
			if (m_columns.size() != colNames.size()) {
				THROW_FORMATERROR("Attempted to import " + std::to_string(colNames.size())
								  + " columns into a pre-defined table that contains " + std::to_string(m_columns.size())
								  + " columns");
			}
			for (std::size_t i = 0; i < colNames.size(); ++i) {
				const std::string currentName = colNames[i].get< std::string >();
				const Column *col             = findColumn(currentName);

				if (!col) {
					THROW_FORMATERROR("A column with the name \"" + currentName
									  + "\" is not part of the pre-defined columns for this table");
				}
				if (DataType::fromSQLRepresentation(colTypes[i].get< std::string >()) != col->getType()) {
					THROW_FORMATERROR("Column type mismatch for column \"" + currentName + "\": Expected: \""
									  + col->getType().sqlRepresentation(m_backend) + "\", got \""
									  + colTypes[i].get< std::string >() + "\"");
				}

				targetColumns.push_back(col);
			}
		} else {
			// Import columns as specified
			m_columns.resize(colNames.size());
			for (std::size_t i = 0; i < colNames.size(); ++i) {
				Column column(colNames[i].get< std::string >(),
							  DataType::fromSQLRepresentation(colTypes[i].get< std::string >()));

				m_columns[i] = std::move(column);
			}
			// m_columns is not modified any further in here, so the pointers stay valid
			for (const Column &currentColumn : m_columns) {
				targetColumns.push_back(&currentColumn);
			}
		}

		if (create) {
			// Now we have all information together that we need in order to create the table
			this->create();
		}

		// From this point on we are assuming that the table represented by this object actually exists in the
		// respective database, so we can now start inserting the provided data into it.
		std::string query            = "INSERT INTO \"" + m_name + "\" (";
		std::string valuePlaceholder = "";
		for (std::size_t i = 0; i < colNames.size(); ++i) {
			const std::string currentName = colNames[i].get< std::string >();
			query += "\"" + currentName + "\"";

			const std::string placeholder = ":" + currentName;

			if (m_backend == Backend::PostgreSQL && targetColumns[i]->getType() == DataType::Binary) {
				// Special case: we have to write the data insertion directly into the query (see binary data handling
				// further below for why)
				valuePlaceholder += "DECODE(" + placeholder + ", 'hex')";
			} else {
				valuePlaceholder += placeholder;
			}

			if (i + 1 < colNames.size()) {
				query += ", ";
				valuePlaceholder += ", ";
			}
		}
		query += ") VALUES(" + valuePlaceholder + ")";

		TransactionHolder transaction = ensureTransaction();

		soci::statement stmt = m_sql.prepare << query;

		std::vector< std::string > values;
		std::vector< soci::blob > binaryValues;
		std::vector< soci::indicator > indicators;

		// The statement binds to the elements of these vectors by reference, so they must not reallocate while a row
		// is being bound. Every column adds at most one element to each vector per row, so reserving one slot per
		// column is sufficient.
		values.reserve(colNames.size());
		binaryValues.reserve(colNames.size());
		indicators.reserve(colNames.size());
		for (const nlohmann::json &currentRow : rows) {
			assert(currentRow.size() == colNames.size());

			// We have to first transfer our values into the values vector in order to guarantee that they
			// are not destroyed in the middle of the DB statement (which might happen, if we were to use
			// the temporaries directly)
			for (std::size_t i = 0; i < currentRow.size(); ++i) {
				const nlohmann::json &currentVal = currentRow[i];
				const Column &targetColumn       = *targetColumns[i];

				if (currentVal.is_null()) {
					values.push_back({});
					indicators.push_back(soci::i_null);

					stmt.exchange(soci::use(values.back(), indicators.back()));

					continue;
				}

				if (targetColumn.getType() == DataType::Blob || targetColumn.getType() == DataType::Binary) {
					// We have to handle binary data special in order to prevent SOCI suggesting to the DB backend that
					// the hex representation (which we expect here) is in fact to be interpreted as a string (which
					// would lead to various undesired behavior depending on the used backend)
					if (targetColumn.getType() == DataType::Binary && m_backend == Backend::PostgreSQL) {
						// In PostgreSQL we can't use a BLOB for inserting into a BYTEA column (at least not via SOCI)
						// as that'd insert the BLOB's OID instead of its contents into the column. Instead, the hex
						// string (without an optional "0x" prefix) is decoded by DECODE(..., 'hex') in the query.
						std::string hexString = utils::to_string(currentVal);
						if (hexString.size() >= 2 && hexString.substr(0, 2) == "0x") {
							hexString = hexString.substr(2);
						}
						values.push_back(std::move(hexString));
						stmt.exchange(soci::use(values.back()));
					} else {
						binaryValues.push_back(soci::blob{ m_sql });
						// Convert hex representation to binary values
						std::vector< std::uint8_t > binary =
							utils::hexToBinary< decltype(binary) >(currentVal.get< std::string >());
						// Write the binary data into a BLOB object, which can be bound to our statement
						binaryValues.back().write_from_start(reinterpret_cast< const char * >(binary.data()),
															 binary.size());

						stmt.exchange(soci::use(binaryValues.back()));
					}
				} else {
					values.push_back(utils::to_string(currentVal));

					stmt.exchange(soci::use(values.back()));
				}
			}

			stmt.define_and_bind();
			stmt.execute(true);
			stmt.bind_clean_up();

			values.clear();
			binaryValues.clear();
			indicators.clear();
		}

		transaction.commit();
	}
#undef THROW_FORMATERROR

	// Exports this table's column specification and all of its rows into a JSON object with the fields
	// "column_names", "column_types" and "rows" (an array of arrays, one entry per column).
	// SOCI errors are rethrown as AccessException.
	nlohmann::json Table::exportToJSON() {
		assert(!m_columns.empty());
		assert(!m_name.empty());

		TransactionHolder transaction = ensureTransaction();

		nlohmann::json json;

		// Record the column specification and, alongside, assemble the quoted list of selected columns
		std::string selectedColumns;
		for (const Column &column : m_columns) {
			json["column_names"].push_back(column.getName());
			json["column_types"].push_back(column.getType().sqlRepresentation(m_backend));

			if (!selectedColumns.empty()) {
				selectedColumns += ", ";
			}
			selectedColumns += "\"" + column.getName() + "\"";
		}

		const std::string query = "SELECT " + selectedColumns + " FROM \"" + m_name + "\"";

		try {
			soci::rowset< soci::row > rowSet = m_sql.prepare << query;

			nlohmann::json rows = nlohmann::json::array_t();
			for (const soci::row &dbRow : rowSet) {
				nlohmann::json jsonRow = nlohmann::json::array_t();
				for (std::size_t col = 0; col < dbRow.size(); ++col) {
					jsonRow.push_back(utils::to_json(dbRow, col));
				}

				rows.push_back(std::move(jsonRow));
			}

			json["rows"] = std::move(rows);
		} catch (const soci::soci_error &e) {
			throw AccessException(e.what());
		}

		transaction.commit();

		return json;
	}

	// Debug-only sanity checks of the table specification given to the constructor (consists solely of assertions)
	void Table::performCtorAssertions() {
		// Spaces in table or column names are rejected, as they cause issues
		assert(m_name.find(' ') == std::string::npos);
#ifndef NDEBUG
		for (const Column &column : m_columns) {
			assert(column.getName().find(' ') == std::string::npos);
		}
#endif

		// Names ending in the backup suffix are reserved for a table's backup copy (used during migrations)
		assert(!boost::ends_with(m_name, Table::BACKUP_SUFFIX));
	}

} // namespace db
} // namespace mumble
